import time

import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import pickle
import collections
from collections import defaultdict
import matplotlib.pyplot as plt
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.layers import Input, Dense, Conv2D, LeakyReLU, Dropout, Flatten, MaxPooling2D, GlobalAveragePooling2D
from keras.layers import BatchNormalization, Embedding, Reshape, Activation
from keras.layers import Concatenate, Conv2DTranspose, multiply, UpSampling2D
from keras.initializers import RandomNormal
from keras.optimizers import Adam
from keras.utils import Progbar
from keras.metrics import *
from keras import backend as K
import nalp.utils.constants as c
from keras.layers import concatenate
import cv2

# https://github.com/leoclementliao/cgan/blob/master/cWGAN_collision.ipynb
def covidGenerator(noise_dim, gen_dim, temperature):
    """Build the MLP generator whose output passes through a Gumbel-Softmax.

    Adapted from https://github.com/leoclementliao/cgan (the conditional
    input of the original cWGAN is currently disabled) and from the
    Gumbel-Softmax layer of https://github.com/gugarosa/nalp.

    Args:
        noise_dim: Length of the noise vector fed to the ``'noise'`` input.
        gen_dim: Width of the final Dense layer, i.e. the number of
            categories the softmax is taken over.
        temperature: Gumbel-Softmax temperature; the logits are divided by
            it. It is captured when the graph is built, so changing it
            afterwards requires rebuilding the model.

    Returns:
        A ``tf.keras.Model`` named ``'Generator'`` mapping noise of shape
        ``(batch, noise_dim)`` to softmax probabilities of shape
        ``(batch, gen_dim)``. Fresh Gumbel noise is drawn on every call.

    Raises:
        ValueError: If ``temperature`` is a plain real number that is not
            strictly positive (including NaN).
    """
    import numbers

    # Only validate plain scalars; tensor/variable temperatures are passed
    # through untouched so graph-valued temperatures keep working.
    if isinstance(temperature, numbers.Real) and not temperature > 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")

    inp_noise = Input(shape=[noise_dim, ], name='noise')

    X = Dense(32, activation='relu')(inp_noise)
    X = Dense(32, activation='relu')(X)
    X = Dense(16, activation='relu')(X)
    logits = Dense(gen_dim)(X)

    # Gumbel-Softmax: add Gumbel(0, 1) noise g = -log(-log(u)), u ~ U(0, 1),
    # to the logits; EPSILON keeps both logarithms finite when u is 0.
    uniform_dist = tf.random.uniform(tf.shape(logits), 0, 1)
    gumbel_dist = -1 * tf.math.log(-1 * tf.math.log(uniform_dist + c.EPSILON) + c.EPSILON)
    last = tf.nn.softmax((logits + gumbel_dist) / temperature, axis=-1)

    return tf.keras.Model(inputs=[inp_noise], outputs=last, name='Generator')


def covidDiscriminator(gen_dim):
    """Build the MLP discriminator that scores a single sample vector.

    Args:
        gen_dim: Length of the input vector; should match the generator's
            ``gen_dim``.

    Returns:
        A ``tf.keras.Model`` named ``'Discriminator'`` with one ``'target'``
        input and a single linear output unit (no activation), so the
        training loss must accept raw, unbounded scores.
    """
    target = tf.keras.layers.Input(shape=[gen_dim, ], name='target')

    hidden = target
    for width in (32, 32, 16):
        hidden = Dense(width, activation='relu')(hidden)

    score = Dense(1)(hidden)
    return tf.keras.Model(inputs=[target], outputs=score, name='Discriminator')


# https://github.com/giocoal/CXR-ACGAN-chest-xray-generator-covid19-pneumonia/tree/main
def xrayGenerator(latent_dim = 100, n_classes = 3):
    """Build the class-conditional (AC-GAN) generator for 112 x 112 x 3 images.

    Adapted from
    https://github.com/giocoal/CXR-ACGAN-chest-xray-generator-covid19-pneumonia

    Args:
        latent_dim: Length of the noise vector input.
        n_classes: Number of distinct labels the label embedding accepts
            (labels are expected to be integer indices in ``[0, n_classes)``).

    Returns:
        A ``keras.Model`` with inputs ``[noise, label]`` -- noise of shape
        ``(batch, latent_dim)`` and label of shape ``(batch, 1)`` -- that
        outputs images of shape ``(batch, 112, 112, 3)`` in ``[-1, 1]``
        (tanh activation).
    """
    def weight_init():
        # Fresh N(0, 0.02) kernel initializer for each layer.
        return RandomNormal(mean=0.0, stddev=0.02)

    def upsample_block(input_layer, filters):
        # 5x5 stride-2 transposed conv doubles height and width,
        # followed by BatchNorm(momentum 0.9) and ReLU.
        x = Conv2DTranspose(filters, kernel_size=(5, 5), strides=(2, 2), padding="same",
                            kernel_initializer=weight_init())(input_layer)
        x = BatchNormalization(momentum=0.9)(x)
        return Activation("relu")(x)

    # Label branch: label -> 100-d embedding -> Dense(7*7) -> 7 x 7 x 1 map.
    label_input = Input(shape=(1,))
    y = Embedding(n_classes, 100)(label_input)
    y = Dense(7 * 7, kernel_initializer=weight_init())(y)
    y = Reshape((7, 7, 1))(y)

    # Noise branch: latent vector -> Dense(7*7*1024) + ReLU -> 7 x 7 x 1024.
    generator_input = Input(shape=(latent_dim,))
    gen = Dense(1024 * 7 * 7, kernel_initializer=weight_init())(generator_input)
    gen = Activation('relu')(gen)
    gen = Reshape((7, 7, 1024))(gen)

    # Channel-wise concatenation of both branches -> 7 x 7 x 1025.
    merge = Concatenate()([gen, y])

    # 7x7x1025 -> 14x14x512 -> 28x28x256 -> 56x56x128.
    gen = upsample_block(merge, 512)
    gen = upsample_block(gen, 256)
    gen = upsample_block(gen, 128)

    # Final 56x56x128 -> 112x112x3 with tanh and no batch norm.
    gen = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding="same",
                          kernel_initializer=weight_init())(gen)
    out_layer = Activation("tanh")(gen)

    return Model(inputs=[generator_input, label_input], outputs=out_layer)


def xrayDiscriminator(input_shape=(112, 112, 3), n_classes=3):
    """Build the AC-GAN discriminator with a real/fake head and a class head.

    Args:
        input_shape: Shape of a single input image, ``(height, width, channels)``.
        n_classes: Number of classes predicted by the auxiliary head.

    Returns:
        A ``keras.Model`` with one image input and two outputs:
        ``'source'`` -- a sigmoid real-vs-fake score of shape ``(batch, 1)``;
        ``'auxiliary'`` -- softmax class probabilities of shape
        ``(batch, n_classes)``.
    """
    def conv_block(input_layer, filter_size, stride):
        # Conv 3x3 ('same' padding, N(0, 0.02) init) -> BatchNorm(0.9)
        # -> LeakyReLU(0.2) -> Dropout(0.5).
        out = Conv2D(filter_size, kernel_size=(3, 3), padding='same', strides=stride,
                     kernel_initializer=RandomNormal(mean=0.0, stddev=0.02))(input_layer)
        out = BatchNormalization(momentum=0.9)(out)
        out = LeakyReLU(alpha=0.2)(out)
        return Dropout(0.5)(out)

    input_img = Input(shape=input_shape)

    # The first block keeps the spatial size; each later stride-2 block halves
    # it (for the default 112x112 input: 112 -> 56 -> 28 -> 14 -> 7).
    stages = ((32, (1, 1)), (64, (2, 2)), (128, (2, 2)), (256, (2, 2)), (512, (2, 2)))
    features = input_img
    for filters, stride in stages:
        features = conv_block(features, filters, stride)

    features = Flatten()(features)

    source = Dense(1, activation='sigmoid', name='source')(features)
    auxiliary = Dense(n_classes, activation='softmax', name='auxiliary')(features)

    return Model(inputs=input_img, outputs=[source, auxiliary])

